#!/usr/bin/env python3
"""
Generic thread pool class. Modeled after Java's ThreadPoolExecutor.
Please note that this ThreadPool does *not* fully implement the PEP 3148
ThreadPool!
"""

import logging
from threading import Thread, Lock, currentThread, current_thread
from weakref import ref

try:
  from queue import Queue, Empty
except ImportError:
  from queue import Queue, Empty

from ambari_agent.ExitHelper import ExitHelper

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Registry of weak references to ThreadPool instances: an entry is added in
# ThreadPool.__init__ and removed in ThreadPool.shutdown(). Weak references
# are used so that the registry itself does not keep a pool alive.
_threadpools = set()


# Worker threads are daemonic in order to let the interpreter exit without
# an explicit shutdown of the thread pool. The following trick is necessary
# to allow worker threads to finish cleanly.
def _shutdown_all():
  for pool_ref in tuple(_threadpools):
    pool = pool_ref()
    if pool:
      pool.shutdown()


# Hand _shutdown_all to the agent's ExitHelper so that live pools get shut
# down when ExitHelper invokes its registered callbacks
# (presumably at agent exit -- TODO confirm against ExitHelper).
ExitHelper().register(_shutdown_all)


class ThreadPool(object):
  """
  Pool of daemon worker threads that consume jobs from one shared queue.

  Workers are created lazily, one per ``submit()`` call, until
  ``max_threads`` is reached. A worker created while fewer than
  ``core_threads`` workers exist waits for jobs indefinitely; any other
  worker exits after waiting ``keepalive`` seconds without receiving a job.
  """

  def __init__(
    self,
    core_threads=0,
    max_threads=20,
    keepalive=1,
    context_injector=None,
    agent_config=None,
  ):
    """
    :param core_threads: maximum number of persistent threads in the pool
    :param max_threads: maximum number of total threads in the pool; the
        effective limit is never lower than ``core_threads`` or 1
    :param keepalive: seconds to keep non-core worker threads waiting
        for new tasks
    :param context_injector: optional callable invoked once at the start of
        every worker thread, with ``agent_config`` as its only argument
    :param agent_config: value passed to ``context_injector``

    :type context_injector func
    :type agent_config AmbariConfig.AmbariConfig
    """
    self.core_threads = core_threads
    self.max_threads = max(max_threads, core_threads, 1)
    self.keepalive = keepalive
    self._queue = Queue()
    self._threads_lock = Lock()
    self._threads = set()
    self._shutdown = False

    self._job_context_injector = context_injector
    self._agent_config = agent_config

    _threadpools.add(ref(self))
    # Report the effective limit: the raw argument may have been raised to
    # core_threads or 1 above, so it can differ from what is enforced.
    logger.info(
      "Started thread pool with %d core threads and %s maximum threads",
      core_threads,
      self.max_threads,
    )

  def _adjust_threadcount(self):
    """Start one more worker if the pool is below ``max_threads``."""
    with self._threads_lock:
      if self.num_threads < self.max_threads:
        self._add_thread(self.num_threads < self.core_threads)

  def _add_thread(self, core):
    """
    Start a daemon worker thread and record it in ``_threads``.

    Expected to be called with ``_threads_lock`` held: the worker takes the
    same lock before removing itself, so it cannot try to unregister before
    it has been registered here.

    :param core: True if the worker should wait for jobs indefinitely,
        False if it should exit after ``keepalive`` idle seconds
    """
    t = Thread(target=self._run_jobs, args=(core,), daemon=True)
    t.start()
    self._threads.add(t)

  def _run_jobs(self, core):
    """
    Worker loop: take jobs off the queue and run them until told to stop.

    The loop ends when a non-core worker times out waiting for a job, or
    when the pool has been shut down by the time an item is dequeued. An
    exception raised by a job is logged and does not end the loop.

    :param core: True to block on the queue without a timeout
    """
    logger.debug("Started worker thread")
    block = True
    timeout = None
    if not core:
      block = self.keepalive > 0
      timeout = self.keepalive

    try:
      if self._job_context_injector is not None:
        self._job_context_injector(self._agent_config)

      while True:
        try:
          func, args, kwargs = self._queue.get(block, timeout)
        except Empty:
          break

        # Checked before touching the item: shutdown() enqueues
        # (None, None, None) placeholders only to wake blocked workers.
        if self._shutdown:
          break

        try:
          # Jobs may be submitted without positional arguments; fall back
          # to the callable itself so logging cannot prevent the job from
          # running.
          logger.debug("Worker thread starting job %s", args[0] if args else func)
          func(*args, **kwargs)
        except Exception:
          logger.exception("Error in worker thread")
    finally:
      # Always unregister, even if the context injector or the loop raised;
      # otherwise a dead thread would keep counting towards max_threads.
      with self._threads_lock:
        self._threads.discard(current_thread())

    logger.debug("Exiting worker thread")

  @property
  def num_threads(self):
    """Number of worker threads currently registered in the pool."""
    return len(self._threads)

  def submit(self, func, *args, **kwargs):
    """
    Queue ``func(*args, **kwargs)`` for execution by a worker thread.

    :raises RuntimeError: if the pool has already been shut down
    """
    if self._shutdown:
      raise RuntimeError("Cannot schedule new tasks after shutdown")

    self._queue.put((func, args, kwargs))
    self._adjust_threadcount()

  def shutdown(self, wait=True):
    """
    Stop the pool; calling it again after the first time is a no-op.

    Workers exit as soon as they dequeue their next item, so jobs that are
    still queued at this point are not executed.

    :param wait: if true, block until the worker threads have finished
    """
    if self._shutdown:
      return

    logger.info("Shutting down thread pool")
    self._shutdown = True
    _threadpools.discard(ref(self))

    with self._threads_lock:
      # One placeholder per registered worker, to wake those blocked in get().
      for _ in range(self.num_threads):
        self._queue.put((None, None, None))
      threads = tuple(self._threads)

    if wait:
      # Join outside the lock: exiting workers need it to unregister.
      for thread in threads:
        thread.join()

  def __repr__(self):
    if self.max_threads:
      threadcount = "%d/%d" % (self.num_threads, self.max_threads)
    else:
      threadcount = "%d" % self.num_threads

    return f"<ThreadPool at {id(self):x}; threads={threadcount}>"
